import numpy as np
import torch
from torch import nn
import torch.nn.functional as F

from cbml_benchmark.modeling.registry import HEADS
from cbml_benchmark.utils.init_methods import weights_init_classifier,weights_init_kaiming


@HEADS.register('linear_norm')
class LinearNorm(nn.Module):
    """Embedding head: one linear projection followed by L2 normalisation.

    Args:
        cfg: config node; ``cfg.MODEL.HEAD.DIM`` is the output width.
        in_channels: input feature size consumed by the linear layer.
    """

    def __init__(self, cfg, in_channels):
        super().__init__()
        out_dim = cfg.MODEL.HEAD.DIM
        self.fc = nn.Linear(in_channels, out_dim)
        # Re-initialise the projection with the project's Kaiming init helper.
        self.fc.apply(weights_init_kaiming)

    def forward(self, x):
        """Project ``x`` with ``self.fc`` and scale to unit L2 norm along dim 1.

        For 2-D input of shape (N, in_channels), dim 1 is the feature axis,
        so every output row has unit length.
        """
        return F.normalize(self.fc(x), p=2, dim=1)


@HEADS.register('linear_norm_map')
class LinearNormMap(nn.Module):
    """Per-position linear embedding of a feature map.

    The input map is flattened to ``H*W`` positions. Each of the first
    ``num_mat`` positions has its own ``nn.Linear(in_channels, embedding_dim)``.
    Feature vectors are L2-normalised before projection. The projected map
    then goes through GELU and is L2-normalised again along the channel axis.

    Args:
        in_channels: channel count ``C`` of the input map.
        embedding_dim: output channel count of every per-position projection.
        num_mat: number of per-position projections. It must not exceed
            ``H*W`` of the input; otherwise indexing raises IndexError.
            Positions ``>= num_mat`` are left at zero in the output.
            NOTE(review): presumably callers use ``num_mat == H*W``; confirm.
    """

    def __init__(self, in_channels, embedding_dim, num_mat):
        super(LinearNormMap, self).__init__()
        self.fc = nn.ModuleList(
            [nn.Linear(in_channels, embedding_dim) for _ in range(num_mat)]
        )
        for layer in self.fc:
            layer.apply(weights_init_kaiming)
        self.dim = embedding_dim
        self.num_mat = num_mat

    def forward(self, x):
        """Embed every spatial position of ``x`` with its own projection.

        Args:
            x: tensor of shape (N, C, H, W), as required by the 4-way unpack.

        Returns:
            Tensor of shape (N, embedding_dim, H, W), on ``x``'s device.
            Each position's channel vector has unit L2 norm, except positions
            beyond ``num_mat``, which are all zeros.
        """
        N, C, H, W = x.size()
        # (N, C, H, W) -> (N, H*W, C): one C-dim feature vector per position.
        feats = x.reshape(N, C, H * W).permute(0, 2, 1)
        feats = F.normalize(feats, p=2, dim=2)
        # Allocate on the input's device/dtype. The old hard-coded 'cuda'
        # broke CPU runs and inputs on non-default GPUs.
        out = feats.new_zeros(N, H * W, self.dim)
        for i in range(self.num_mat):
            out[:, i, :] = self.fc[i](feats[:, i, :])
        # F.normalize maps all-zero rows to zero (norm clamped by eps), and
        # gelu(0) == 0, so unprojected positions stay zero.
        out = F.normalize(F.gelu(out), p=2, dim=2)
        # (N, H*W, D) -> (N, D, H, W)
        return out.permute(0, 2, 1).reshape(N, self.dim, H, W)


